import os
import json
import pickle
import random

import matplotlib.pyplot as plt


def read_split_data(root: str, val_rate: float = 0.2):
    '''
    划分训练集和测试集
    输入：
        root: 数据路径
        val_rate: 划分比例
    返回:
        train_images_path       # 列表：存储训练集的所有图片【路径】
        train_images_label      # 列表：存储训练集图片对应【类别索引】信息
        val_images_path         # 列表：存储验证集的所有图片【路径】
        val_images_label        # 列表：存储验证集图片对应【类别索引】信息
    '''
    random.seed(0)  # 保证随机结果可复现
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)  # --增加监控

    # 1.遍历文件夹（一个文件夹对应一个类别），得到类别索引
    # isdir: 判断是否为文件夹，只保留文件夹名称cal（并不是路径）
    class_names = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
    # 排序，保证顺序一致
    class_names.sort()
    # 生成类别名称以及对应的数字索引（将列表转换为：名称 -> 索引）
    class_indices = dict((k, v) for v, k in enumerate(class_names))
    # 将字典(索引 -> 名称)写入json文件
    json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    train_images_path = []      # 存储训练集的所有图片路径
    train_images_label = []     # 存储训练集图片对应索引信息
    val_images_path = []        # 存储验证集的所有图片路径
    val_images_label = []       # 存储验证集图片对应索引信息
    every_class_num = []        # 存储每个类别的样本总数
    supported = [".jpg", ".JPG", ".png", ".PNG"]  # 支持的文件后缀类型
    # 2.遍历每个文件夹下的文件，同时按随机比例划分数据集
    for cla in class_names:
        # 每个文件夹的路径（一个文件夹就是一个类别）
        cla_path = os.path.join(root, cla)
        # 遍历获取supported支持的所有图片文件完整路径
        # root：总数据集文件夹的根目录
        # cla：子文件夹名
        # i：图片名称+后缀
        images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
                  if os.path.splitext(i)[-1] in supported]
        # 获取该类别对应的【索引】
        image_class = class_indices[cla]
        # 记录该类别的样本数量（该文件夹下的图片数量）
        every_class_num.append(len(images))
        # 按比例随机采样【验证样本】
        val_path = random.sample(images, k=int(len(images) * val_rate))

        for img_path in images:
            if img_path in val_path:  # 如果该路径在采样的验证集样本中则存入验证集
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:  # 否则存入训练集
                train_images_path.append(img_path)
                train_images_label.append(image_class)

    print("{} images were found in the dataset.".format(sum(every_class_num)))
    print("{} images for training.".format(len(train_images_path)))
    print("{} images for validation.".format(len(val_images_path)))

    #plot_image = False
    plot_image = True    #---可修改：是否绘制每个类别的个数，修改为True，在return设置断点
    if plot_image:
        # 绘制每种类别个数柱状图
        plt.bar(range(len(class_names)), every_class_num, align='center')
        # 将横坐标0,1,2,3,4替换为相应的类别名称
        plt.xticks(range(len(class_names)), class_names)
        # 在柱状图上添加数值标签
        for i, v in enumerate(every_class_num):
            plt.text(x=i, y=v + 5, s=str(v), ha='center')
        # 设置x坐标
        plt.xlabel('image class')
        # 设置y坐标
        plt.ylabel('number of images')
        # 设置柱状图的标题
        # plt.title('class distribution')
        # centos中需要将图片保存下来查看
        plt.savefig("class_names.jpg")   #--centos中需要
        plt.show()

    return [train_images_path, train_images_label], [val_images_path, val_images_label], len(class_names)


def plot_data_loader_image(data_loader,
                           json_path = './class_indices.json'):
    '''
    批量绘制出DataLoader图片
    注意：保存到本地 或者 交互显示
    输入：
        data_loader
    '''
    batch_size = data_loader.batch_size
    plot_num = min(batch_size, 4)

    # 读入标签和类别对应的json文件【loader里面的label仅仅只是索引，索引对应的标签值需要json文件来对应】
    # assert os.path.exists(json_path), json_path + " does not exist."
    # 判断json文件是否存在
    is_exists_class_indices = os.path.exists(json_path)      
    if(is_exists_class_indices):
        # 文件存在就打开
        json_file = open(json_path, 'r')
        class_indices = json.load(json_file)
    else:
        print(json_path + " does not exist.")

    # 遍历DataLoader
    for data in data_loader:
        images, labels = data
        for i in range(plot_num):
            # transpose: 调整通道顺序
            # [C, H, W] -> [H, W, C] 
            img = images[i].numpy().transpose(1, 2, 0)
            # 反Normalize操作
            img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
            # item: 取数值
            label = labels[i].item()
            plt.subplot(1, plot_num, i+1)    #行，列，当前是第几张
            # 如果有对应标签值就打印值，否则只打印索引
            if(is_exists_class_indices):
                plt.xlabel(class_indices[str(label)])
            else:
                plt.xlabel(label)

            plt.xticks([])  # 去掉x轴的刻度
            plt.yticks([])  # 去掉y轴的刻度
            plt.imshow(img.astype('uint8'))  # 将float类型转换为int类型
        
        #plt.savefig("batch_data_loader_image.jpg")   #--centos中需要
        plt.show()
        break   # ---为了不全部输出，只输出一个批量


def write_pickle(list_info: list, file_name: str):
    with open(file_name, 'wb') as f:
        pickle.dump(list_info, f)


def read_pickle(file_name: str) -> list:
    with open(file_name, 'rb') as f:
        info_list = pickle.load(f)
        return info_list
